# encoding=utf-8
"""
@author: xiao nian
@contact: xiaonian030@163.com
@time: 2021-12-12 15:30
"""
from sklearn.model_selection import train_test_split
from config.config import MODEL_CONFIG


def data(path='data/train.txt'):
    """Load the labelled corpus and split it into train/validation/test sets.

    Every line of the input file is split on the ``_!_`` separator. The field
    at index 2 is taken as the label and the field at index 3 as the text,
    both stripped of surrounding whitespace. Lines with fewer than four
    fields, or whose text or label is empty after stripping, are skipped.

    When ``MODEL_CONFIG['multi_label']`` is truthy each label is wrapped in a
    one-element list; otherwise the label is stored as a plain string.

    Args:
        path: Location of the training file. Defaults to ``'data/train.txt'``,
            which is resolved relative to the current working directory.

    Returns:
        A tuple ``(train_x, train_y, valid_x, valid_y, test_x, test_y)``.
        Roughly 70% of the samples go to the training set and the remaining
        ~30% is halved between validation and test. The split is reproducible
        because a fixed ``random_state`` is used.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    # Corpus: parallel lists of texts and their labels.
    text_list = []
    label_list = []

    # Read and clean the raw data. Iterating the file object streams it line
    # by line instead of loading every line into memory first.
    with open(path, mode='r', encoding='utf-8') as f:
        for line in f:
            cols = line.split('_!_')
            if len(cols) < 4:
                # Malformed line: too few fields to contain label and text.
                continue
            # strip() already removes the trailing newline; a line read in
            # text mode cannot contain an embedded one.
            text_item = cols[3].strip()
            label_item = cols[2].strip()
            if not text_item or not label_item:
                continue
            text_list.append(text_item)
            if MODEL_CONFIG['multi_label']:
                # Multi-label mode: each sample carries a list of labels.
                label_list.append([label_item])
            else:
                label_list.append(label_item)

    # Build the train / validation / test sets.
    # train_test_split returns (train, test, ...) where the *second* part has
    # the size given by test_size. The names are deliberately assigned the
    # other way round here so the 70% part becomes the training set and the
    # 30% remainder is split evenly into validation and test.
    remain_x, train_x, remain_y, train_y = train_test_split(
        text_list, label_list, test_size=0.7, random_state=42)
    valid_x, test_x, valid_y, test_y = train_test_split(
        remain_x, remain_y, test_size=0.5, random_state=42)
    return train_x, train_y, valid_x, valid_y, test_x, test_y
